{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Discovering Important Words for Sentiments With NormLIME"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook loads a pretrained BiLSTM model, trained following [PaddleNLP TextClassification](https://github.com/PaddlePaddle/models/tree/release/2.0-beta/PaddleNLP/examples/text_classification/rnn), and performs sentiment analysis on review data. The full official PaddlePaddle sentiment classification tutorial can be found [here](https://github.com/PaddlePaddle/models/tree/release/2.0-beta/PaddleNLP/examples/text_classification).\n",
    "\n",
    "The NormLIME method aggregates local LIME models into global and class-specific interpretations. It is effective at recognizing important features. In this notebook, we use NormLIME, specifically `NormLIMENLPInterpreter`, to discover the words that contribute the most to positive and negative sentiment predictions."
   ]
  },
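  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the aggregation step, the cell below sketches the NormLIME idea on toy data: each local explanation maps words to weights, every absolute weight is scaled by its share of that explanation's L1 norm, and the scaled values are averaged over the samples in which the word appears. The function name `normlime_aggregate` and the toy weights are hypothetical, and this is a simplified sketch under those assumptions rather than InterpretDL's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def normlime_aggregate(local_weights):\n",
    "    # local_weights: list of {word: weight} dicts, one per explained sample.\n",
    "    totals = defaultdict(float)\n",
    "    counts = defaultdict(int)\n",
    "    for weights in local_weights:\n",
    "        l1 = sum(abs(w) for w in weights.values())\n",
    "        if l1 == 0:\n",
    "            continue  # an all-zero explanation carries no information\n",
    "        for word, w in weights.items():\n",
    "            # |w| scaled by its share |w| / l1 of this local explanation\n",
    "            totals[word] += abs(w) * abs(w) / l1\n",
    "            counts[word] += 1\n",
    "    # One (aggregated weight, frequency) pair per word.\n",
    "    return {word: (totals[word] / counts[word], counts[word]) for word in totals}\n",
    "\n",
    "# Toy local explanations for two samples (made-up numbers).\n",
    "toy_local = [{'good': 0.6, 'the': 0.2}, {'good': 0.3, 'bad': -0.1}]\n",
    "toy_global = normlime_aggregate(toy_local)\n",
    "print(toy_global)"
   ]
  },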
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import interpretdl as it\n",
    "import jieba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings \n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the word dictionary and specify the path to the pretrained model. Set `unk_id` to the word id of the *\\[UNK\\]* token. Other possible choices include the empty token *\\\"\\\"* and the *\\[PAD\\]* token.\n",
    "\n",
    "To obtain the pretrained weights, train a BiLSTM model following the [tutorial](https://github.com/PaddlePaddle/models/tree/release/2.0-beta/PaddleNLP/examples/text_classification/rnn) and set `PARAMS_PATH` below to the path of the final `.pdparams` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_vocab(vocab_file):\n",
    "    \"\"\"Loads a vocabulary file into a dictionary.\"\"\"\n",
    "    vocab = {}\n",
    "    with open(vocab_file, \"r\", encoding=\"utf-8\") as reader:\n",
    "        tokens = reader.readlines()\n",
    "    for index, token in enumerate(tokens):\n",
    "        token = token.rstrip(\"\\n\").split(\"\\t\")[0]\n",
    "        vocab[token] = index\n",
    "    return vocab\n",
    "\n",
    "PARAMS_PATH = \"assets/final.pdparams\"\n",
    "VOCAB_PATH = \"assets/senta_word_dict.txt\"\n",
    "\n",
    "vocab = load_vocab(VOCAB_PATH)\n",
    "unk_id = vocab['[UNK]']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the BiLSTM model using **paddlenlp.models** and load pretrained weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddlenlp as ppnlp\n",
    "model = ppnlp.models.Senta(\n",
    "        network='bilstm',\n",
    "        vocab_size=len(vocab),\n",
    "        num_classes=2)\n",
    "\n",
    "state_dict = paddle.load(PARAMS_PATH)\n",
    "model.set_dict(state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a preprocessing function that takes **a raw string** and returns the inputs that can be fed into the model.\n",
    "\n",
    "In this case, the raw string is split into tokens and mapped to word ids. *texts* is a list of lists, where each inner list contains a sequence of word ids. *seq_lens* is a list that contains the unpadded length of each sequence in *texts*.\n",
    "\n",
    "Since the input is a single raw string, both *texts* and *seq_lens* have length 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_fn(text):\n",
    "    # Tokenize the raw string with jieba (Chinese word segmentation).\n",
    "    tokens = \" \".join(jieba.cut(text)).split(' ')\n",
    "    # Map tokens to word ids; out-of-vocabulary tokens fall back to [UNK].\n",
    "    unk_id = vocab.get('[UNK]', None)\n",
    "    ids = []\n",
    "    for token in tokens:\n",
    "        wid = vocab.get(token, unk_id)\n",
    "        # Note: this truthiness check skips a missing id (None) and also id 0.\n",
    "        if wid:\n",
    "            ids.append(wid)\n",
    "    # A single string gives a batch of size 1, so no padding is applied.\n",
    "    texts = [ids]\n",
    "    seq_lens = [len(ids)]\n",
    "\n",
    "    texts = paddle.to_tensor(texts)\n",
    "    seq_lens = paddle.to_tensor(seq_lens)\n",
    "    return texts, seq_lens"
   ]
  },
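  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset we use is the **ChnSentiCorp** dataset from paddlenlp. If paddlenlp is not installed yet, install it with the command below and then rerun the cells above that import it."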
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install paddlenlp==2.0.0b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the first 1200 samples in the training set as our data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total of 1200 sentences\n"
     ]
    }
   ],
   "source": [
    "from paddlenlp.datasets import ChnSentiCorp\n",
    "\n",
    "train_ds = ChnSentiCorp.get_datasets(['train'])\n",
    "data = [d[0] for d in list(train_ds)[:1200]]\n",
    "print('total of %d sentences' % len(data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the `NormLIMENLPInterpreter`. We save the intermediate results to an *.npz* file so that we don't have to repeat the whole process if we rerun on the same dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "normlime = it.NormLIMENLPInterpreter(\n",
    "    model, temp_data_file='assets/all_lime_weights_nlp.npz')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run `interpret` on the whole dataset. This may take some time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1200 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.842 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "100%|██████████| 1200/1200 [02:21<00:00,  8.45it/s]\n"
     ]
    }
   ],
   "source": [
    "normlime_weights = normlime.interpret(\n",
    "    data,\n",
    "    preprocess_fn,\n",
    "    unk_id=unk_id,\n",
    "    pad_id=0,\n",
    "    num_samples=500,\n",
    "    batch_size=50)"
   ]
  },
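  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LIME-style interpreters typically explain one sample by querying the model on perturbed copies of it. The cell below is a minimal sketch of that idea for text, assuming perturbation replaces randomly chosen word ids with `unk_id`. The function name `perturb_ids`, the toy ids, and the masking scheme are illustrative assumptions, not InterpretDL's actual sampling code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def perturb_ids(ids, unk_id, num_samples, seed=0):\n",
    "    # Hypothetical helper: build num_samples masked copies of one sentence.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    ids = np.asarray(ids)\n",
    "    # 1 keeps a word, 0 replaces it with unk_id.\n",
    "    masks = rng.integers(0, 2, size=(num_samples, len(ids)))\n",
    "    masks[0] = 1  # the first row keeps the original sentence\n",
    "    perturbed = np.where(masks == 1, ids, unk_id)\n",
    "    return masks, perturbed\n",
    "\n",
    "# Toy word ids (made up); unk_id=1 is an assumption for this example only.\n",
    "toy_masks, toy_perturbed = perturb_ids([12, 7, 345, 9], unk_id=1, num_samples=5)\n",
    "print(toy_perturbed)"
   ]
  },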
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cells below, we show the 20 words with the largest weights for positive and negative sentiment. Only words that appear at least 5 times are included."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>word</th>\n",
       "      <th>weight</th>\n",
       "      <th>frequency</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>小巧</td>\n",
       "      <td>0.066120</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>不错</td>\n",
       "      <td>0.057010</td>\n",
       "      <td>290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>好书</td>\n",
       "      <td>0.036362</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>性价比</td>\n",
       "      <td>0.035803</td>\n",
       "      <td>42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>干净</td>\n",
       "      <td>0.033354</td>\n",
       "      <td>27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>每次</td>\n",
       "      <td>0.033267</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>海边</td>\n",
       "      <td>0.028790</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>热情</td>\n",
       "      <td>0.027850</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>一句</td>\n",
       "      <td>0.027751</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>很漂亮</td>\n",
       "      <td>0.027581</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>总的来说</td>\n",
       "      <td>0.026612</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>精致</td>\n",
       "      <td>0.026512</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>清晰</td>\n",
       "      <td>0.025618</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>大方</td>\n",
       "      <td>0.024577</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>品牌</td>\n",
       "      <td>0.024470</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>拥有</td>\n",
       "      <td>0.023639</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>漂亮</td>\n",
       "      <td>0.022731</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>超值</td>\n",
       "      <td>0.022219</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>很快</td>\n",
       "      <td>0.022004</td>\n",
       "      <td>25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>外观</td>\n",
       "      <td>0.020733</td>\n",
       "      <td>51</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    word    weight  frequency\n",
       "0     小巧  0.066120         12\n",
       "1     不错  0.057010        290\n",
       "2     好书  0.036362         17\n",
       "3    性价比  0.035803         42\n",
       "4     干净  0.033354         27\n",
       "5     每次  0.033267         13\n",
       "6     海边  0.028790          7\n",
       "7     热情  0.027850         20\n",
       "8     一句  0.027751          5\n",
       "9    很漂亮  0.027581          9\n",
       "10  总的来说  0.026612         10\n",
       "11    精致  0.026512          6\n",
       "12    清晰  0.025618         10\n",
       "13    大方  0.024577          6\n",
       "14    品牌  0.024470          9\n",
       "15    拥有  0.023639          6\n",
       "16    漂亮  0.022731         26\n",
       "17    超值  0.022219         12\n",
       "18    很快  0.022004         25\n",
       "19    外观  0.020733         51"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "# Invert the vocabulary so that word ids map back to words.\n",
    "id2word = dict(zip(vocab.values(), vocab.keys()))\n",
    "# Positive sentiment (class 1): each entry is (weight, frequency).\n",
    "temp = {\n",
    "    id2word[wid]: normlime_weights[1][wid]\n",
    "    for wid in normlime_weights[1]\n",
    "}\n",
    "# Keep words that appear at least 5 times, then show the 20 largest weights.\n",
    "W = [(word, weight[0], weight[1]) for word, weight in temp.items() if weight[1] >= 5]\n",
    "pd.DataFrame(data=sorted(W, key=lambda x: -x[1])[:20], columns=['word', 'weight', 'frequency'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>word</th>\n",
       "      <th>weight</th>\n",
       "      <th>frequency</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>很差</td>\n",
       "      <td>0.002387</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>是不是</td>\n",
       "      <td>0.000430</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>不好</td>\n",
       "      <td>0.000397</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>不</td>\n",
       "      <td>0.000387</td>\n",
       "      <td>220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>不会</td>\n",
       "      <td>0.000318</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>贵</td>\n",
       "      <td>0.000273</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>隔音</td>\n",
       "      <td>0.000267</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>不如</td>\n",
       "      <td>0.000247</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>不是</td>\n",
       "      <td>0.000245</td>\n",
       "      <td>37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>只能</td>\n",
       "      <td>0.000242</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>电话</td>\n",
       "      <td>0.000242</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>或</td>\n",
       "      <td>0.000221</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>根本</td>\n",
       "      <td>0.000221</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>光驱</td>\n",
       "      <td>0.000216</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>发现</td>\n",
       "      <td>0.000211</td>\n",
       "      <td>37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>失望</td>\n",
       "      <td>0.000194</td>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>唯一</td>\n",
       "      <td>0.000185</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>太</td>\n",
       "      <td>0.000183</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>指纹</td>\n",
       "      <td>0.000180</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>差</td>\n",
       "      <td>0.000178</td>\n",
       "      <td>37</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   word    weight  frequency\n",
       "0    很差  0.002387          9\n",
       "1   是不是  0.000430          8\n",
       "2    不好  0.000397         35\n",
       "3     不  0.000387        220\n",
       "4    不会  0.000318         18\n",
       "5     贵  0.000273          6\n",
       "6    隔音  0.000267         14\n",
       "7    不如  0.000247          9\n",
       "8    不是  0.000245         37\n",
       "9    只能  0.000242         26\n",
       "10   电话  0.000242         13\n",
       "11    或  0.000221          6\n",
       "12   根本  0.000221         10\n",
       "13   光驱  0.000216          6\n",
       "14   发现  0.000211         37\n",
       "15   失望  0.000194         23\n",
       "16   唯一  0.000185          5\n",
       "17    太  0.000183         34\n",
       "18   指纹  0.000180          7\n",
       "19    差  0.000178         37"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Negative sentiment (class 0): each entry is (weight, frequency).\n",
    "temp = {\n",
    "    id2word[wid]: normlime_weights[0][wid]\n",
    "    for wid in normlime_weights[0]\n",
    "}\n",
    "# Keep words that appear at least 5 times, then show the 20 largest weights.\n",
    "W = [(word, weight[0], weight[1]) for word, weight in temp.items() if weight[1] >= 5]\n",
    "pd.DataFrame(data=sorted(W, key=lambda x: -x[1])[:20], columns=['word', 'weight', 'frequency'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "firstEnv",
   "language": "python",
   "name": "firstenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
